Skip to content

Conversation

@kaix-nv
Copy link
Contributor

@kaix-nv kaix-nv commented Nov 7, 2025

What does this PR do?

Type of change: ?
New feature

Overview: ?
This PR provides a sparse attention support in ModelOpt for applying attention sparsity through skip softmax method, enabling inference speedups for LLMs.

Key Features:

  • Skip softmax support
  • Sparse attention config
  • Extensible method registry for future sparse attention algorithms
  • HuggingFace Transformers integration
  • Phase-aware thresholds (separate prefill/decode)

Design doc

Usage

import torch
import modelopt.torch.sparsity.attention_sparsity as mts
from transformers import AutoModelForCausalLM

# Load model (must use eager attention for softmax patching)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="eager",  # Required!
    torch_dtype=torch.bfloat16,
)

# Use pre-defined configuration
from modelopt.torch.sparsity.attention_sparsity import SKIP_SOFTMAX_DEFAULT
model = mts.sparsify(model, SKIP_SOFTMAX_DEFAULT)

Testing

Unit Test

pytest tests/unit/torch/sparsity/attention_sparsity -v
pytest tests/gpu/torch/sparsity/attention_sparsity -v
pytest tests/examples/llm_sparsity/attention_sparsity -v

ALL PASSED.

Accuracy

Benchmark: MMLU
Model: Qwen/Qwen3-4B
Cmd: python mmlu.py --model_name causal --model_path Qwen/Qwen3-4B --sparse_cfg SKIP_SOFTMAX_DEFAULT

MMLU
BF16 69.96
SKIP_SOFTMAX_DEFAULT 69.86

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

@kaix-nv kaix-nv requested a review from a team as a code owner November 7, 2025 07:53
@kaix-nv kaix-nv requested review from realAsma and removed request for realAsma November 7, 2025 07:53
@codecov
Copy link

codecov bot commented Nov 7, 2025

Codecov Report

❌ Patch coverage is 89.67254% with 41 lines in your changes missing coverage. Please review.
✅ Project coverage is 74.95%. Comparing base (fa84955) to head (cd6fce2).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
...ch/sparsity/attention_sparsity/sparse_attention.py 71.42% 16 Missing ⚠️
...pt/torch/sparsity/attention_sparsity/conversion.py 93.02% 9 Missing ⚠️
...y/attention_sparsity/methods/flash_skip_softmax.py 90.90% 8 Missing ⚠️
...ch/sparsity/attention_sparsity/methods/registry.py 88.88% 3 Missing ⚠️
modelopt/torch/sparsity/attention_sparsity/mode.py 90.32% 3 Missing ⚠️
...delopt/torch/sparsity/attention_sparsity/config.py 95.91% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #527      +/-   ##
==========================================
+ Coverage   74.64%   74.95%   +0.31%     
==========================================
  Files         183      192       +9     
  Lines       18542    18939     +397     
==========================================
+ Hits        13840    14196     +356     
- Misses       4702     4743      +41     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_core branch 4 times, most recently from 54bfe2c to 0ce1376 Compare November 8, 2025 03:31
@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_core branch from fc9d285 to 5d027e0 Compare November 11, 2025 23:44
@kaix-nv kaix-nv changed the title [2/n] Add Core Sparse Attention Infrastructure [OMNIML-2852][2/n] Add Core Sparse Attention Infrastructure Nov 12, 2025
@kaix-nv kaix-nv changed the title [OMNIML-2852][2/n] Add Core Sparse Attention Infrastructure [OMNIML-2852] [2/n] Add Core Sparse Attention Infrastructure Nov 12, 2025
@cjluo-nv
Copy link
Collaborator

Hi @kaix-nv could you further split this code change? This PR has 3000+ lines of code change and many file moves

@kevalmorabia97 kevalmorabia97 removed the request for review from RalphMao December 1, 2025 19:07


# Create registry for sparse attention modules
SparseAttentionRegistry = _DMRegistryCls("SparseAttention", SparseAttentionModule)
Copy link
Collaborator

@kevalmorabia97 kevalmorabia97 Dec 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use a single registry for all Sparsity algorithms and modes and then use top-level mts.sparsify(model, mode=...) so all algorithms (e.g. weight or attention sparsify) are invoked by single shared API instead of separate API per algorithm?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good advice. I'll submit a follow-up PR later.

Copy link

@jy-yuan jy-yuan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work on the overall architecture!

Comment on lines +193 to +194
total_blocks = (
num_block_rows * (num_block_rows + 1) // 2 # Causal: N(N+1)/2
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does that means rows==columns? Which means we only have causal in self-attention, not cross-attention?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's for causal attention in prefill.

"--backend",
type=str,
default="pytorch",
choices=["pytorch", "triton"],
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is "triton" a TODO?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

Comment on lines 351 to 352
method = getattr(module, "_method", "unknown")
threshold = getattr(module, "_threshold", "N/A")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do SparseAttentionModule have _method or _threshold?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

print_sparse_attention_summary isn’t used in this PR, I’ve removed it. It will be introduced in the next PR.

Comment on lines +188 to +210
def restore_sparse_attention_state(model: nn.Module, state_dict: dict[str, Any]):
"""Restore sparse attention state from state dict.

Args:
model: Model with sparse attention modules
state_dict: Saved state dictionary
"""
for name, module in model.named_modules():
if isinstance(module, SparseAttentionModule):
module_name = get_unwrapped_name(name, model)
if module_name in state_dict:
module_state = state_dict[module_name]

# Restore method and config
if "method" in module_state:
module._method = module_state["method"]
if "method_config" in module_state:
# Restore config attributes
for key, val in module_state["method_config"].items():
setattr(module, f"_{key}", val)

# Re-setup with restored config
module._setup()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need add test for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_restore_sparse_attention_model covers the test for this func.

@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_core branch from cd6fce2 to 0ca4d20 Compare December 8, 2025 21:32
@kaix-nv
Copy link
Contributor Author

kaix-nv commented Dec 8, 2025

@kevalmorabia97 I've addressed the review suggestions. Could you please review and approve the PR so I can move forward with the subsequent PRs? Thanks.

@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_core branch from 0ca4d20 to 02182f8 Compare December 9, 2025 00:05
Copy link
Collaborator

@kevalmorabia97 kevalmorabia97 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only reviewed high-level structure. Would suggest someone else take a look at the core sparsity logic as I'm not familiar with it

@kaix-nv
Copy link
Contributor Author

kaix-nv commented Dec 9, 2025

Only reviewed high-level structure. Would suggest someone else take a look at the core sparsity logic as I'm not familiar with it
I've asked @jy-yuan to review since he's very familiar with the core logic.
@jy-yuan Please approve if you think the PR is in good shape. Thanks.

@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_core branch from 02182f8 to 8acf333 Compare December 9, 2025 15:41
@kaix-nv kaix-nv requested a review from a team as a code owner December 9, 2025 15:41
@realAsma
Copy link
Contributor

Looks great!

Should we have a simpler high-level usage which aligns with mtq?

# Use pre-defined configuration
model = mts.sparsify(model, mts.SPARSE_ATTEN_SKIP_SOFTMAX_CFG)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants